{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据\n",
    "\n",
    "在tensorflow中准备图片数据的常用方案有两种，第一种是使用tf.keras中的ImageDataGenerator工具构建图片数据生成器。\n",
    "\n",
    "第二种是使用tf.data.Dataset搭配tf.image中的一些图片处理方法构建数据管道。\n",
    "\n",
    "第一种方法更为简单，其使用范例可以参考下面链接中的文章；本文使用第二种方法。\n",
    "\n",
    "https://zhuanlan.zhihu.com/p/67466552"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import datasets,layers,models\n",
    "BATCH_SIZE=100\n",
    "def load_image(img_path,size=(32,32)):\n",
    "    \"\"\"Read one JPEG file and return an ``(image, label)`` pair.\n",
    "\n",
    "    Args:\n",
    "        img_path: string tensor holding the path of a JPEG file.\n",
    "        size: ``(height, width)`` the decoded image is resized to.\n",
    "\n",
    "    Returns:\n",
    "        img: the resized image divided by 255.0, with 3 channels.\n",
    "        label: int8 scalar, 1 when the path matches \".*automobile.*\", else 0.\n",
    "    \"\"\"\n",
    "    # Cast the boolean regex match straight to int8 instead of branching on a\n",
    "    # tensor in Python, so the function does not rely on AutoGraph conversion.\n",
    "    is_automobile = tf.strings.regex_full_match(img_path,\".*automobile.*\")\n",
    "    label = tf.cast(is_automobile,tf.int8)\n",
    "    img = tf.io.read_file(img_path)\n",
    "    # Files are JPEG; channels=3 forces RGB output so a grayscale file still\n",
    "    # yields the 3 channels the model input expects.\n",
    "    img = tf.image.decode_jpeg(img,channels=3)\n",
    "    img = tf.image.resize(img,size)/255.0\n",
    "    return(img,label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use parallel preprocessing (num_parallel_calls) and prefetching (prefetch)\n",
    "# to improve input pipeline performance.\n",
    "ds_train = (\n",
    "    tf.data.Dataset.list_files(\"./data/cifar2/train/*/*.jpg\")\n",
    "    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
    "    .shuffle(buffer_size=1000)\n",
    "    .batch(BATCH_SIZE)\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
    ")\n",
    "\n",
    "# The test pipeline is the same except that it has no shuffle step.\n",
    "ds_test = (\n",
    "    tf.data.Dataset.list_files(\"./data/cifar2/test/*/*.jpg\")\n",
    "    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
    "    .batch(BATCH_SIZE)\n",
    "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "# Preview a few training samples in a 3x3 grid, titled with their labels.\n",
    "from matplotlib import pyplot as plt \n",
    "\n",
    "plt.figure(figsize=(8,8))\n",
    "samples = ds_train.unbatch().take(9)\n",
    "for idx,(image,tag) in enumerate(samples,start=1):\n",
    "    ax = plt.subplot(3,3,idx)\n",
    "    ax.imshow(image.numpy())\n",
    "    ax.set_title(\"label = %d\"%tag)\n",
    "    # Hide the axis ticks; they carry no information for image thumbnails.\n",
    "    ax.set(xticks=[],yticks=[])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x,y in ds_train.take(1):\n",
    "    print(x.shape,y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型 \n",
    "使用函数式API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functional-API model: the layers below are applied to `inputs` in order,\n",
    "# ending in a single sigmoid unit.\n",
    "inputs = layers.Input(shape=(32,32,3))\n",
    "\n",
    "# NOTE(review): both Conv2D layers are created without an activation\n",
    "# argument - confirm that this is intended.\n",
    "layer_stack = [\n",
    "    layers.Conv2D(32,kernel_size=(3,3)),\n",
    "    layers.MaxPool2D(),\n",
    "    layers.Conv2D(64,kernel_size=(5,5)),\n",
    "    layers.MaxPool2D(),\n",
    "    layers.Dropout(rate=0.1),\n",
    "    layers.Flatten(),\n",
    "    layers.Dense(32,activation='relu'),\n",
    "    layers.Dense(1,activation = 'sigmoid'),\n",
    "]\n",
    "\n",
    "# Thread the tensor through every layer; the last result is the model output.\n",
    "outputs = inputs\n",
    "for layer in layer_stack:\n",
    "    outputs = layer(outputs)\n",
    "\n",
    "model = models.Model(inputs = inputs,outputs = outputs)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "\n",
    "# Give every run its own timestamped TensorBoard log directory;\n",
    "# os.path.join builds the path with the current OS's separator.\n",
    "stamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "logdir = os.path.join('data', 'autograph', stamp)\n",
    "\n",
    "tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)\n",
    "\n",
    "model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),\n",
    "        loss=tf.keras.losses.binary_crossentropy,\n",
    "        metrics=[\"accuracy\"]\n",
    "    )\n",
    "\n",
    "# `workers` is intentionally not passed: Keras documents it as applying to\n",
    "# generator / Sequence inputs only, and ds_train is a tf.data.Dataset whose\n",
    "# parallelism is already configured through map/prefetch.\n",
    "history = model.fit(ds_train,epochs= 10,validation_data=ds_test,\n",
    "                    callbacks = [tensorboard_callback])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
